{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nupic.research.frameworks.dynamic_sparse.networks import *\n",
    "from nupic.research.frameworks.dynamic_sparse.models import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = dict(\n",
    "    device=\"cpu\",\n",
    "    network=\"GSCHeb\",\n",
    "    optim_alg=\"SGD\",\n",
    "    momentum=0,  # 0.9,\n",
    "    learning_rate=0.01,  # 0.1,\n",
    "    weight_decay=0.01,  # 1e-4,\n",
    "    lr_scheduler=\"MultiStepLR\",\n",
    "    lr_milestones=[30, 60, 90],\n",
    "    lr_gamma=0.9,  # 0.1,\n",
    "    use_kwinners=True,\n",
    "    model=\"DSNNWeightedMag\",\n",
    "    on_perc=0.04,\n",
    "    weight_prune_perc=0.3,\n",
    "\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'name': 'Conv2d', 'index': 2, 'shape': torch.Size([64, 1, 5, 5]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 64.0},\n",
       " {'name': 'Conv2d', 'index': 6, 'shape': torch.Size([64, 64, 5, 5]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 4096.0},\n",
       " {'name': 'DSLinear', 'index': 13, 'shape': torch.Size([1000, 1600]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 64000.0},\n",
       " {'name': 'DSLinear', 'index': 17, 'shape': torch.Size([12, 1000]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 480.0}]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network = GSCHeb(config)\n",
    "model = DSNNWeightedMag(network, config)\n",
    "model.setup()\n",
    "model.sparse_modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        ...,\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modules = model.sparse_modules\n",
    "modules[0].get_coactivations()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'name': 'Conv2d', 'index': 2, 'shape': torch.Size([64, 1, 5, 5]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 64.0},\n",
       " {'name': 'Conv2d', 'index': 6, 'shape': torch.Size([64, 64, 5, 5]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 4096.0},\n",
       " {'name': 'Linear', 'index': 12, 'shape': torch.Size([1000, 1600]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 64000.0},\n",
       " {'name': 'Linear', 'index': 15, 'shape': torch.Size([12, 1000]), 'on_perc': 0.04, 'hebbian_prune': None, 'weight_prune': 0.3, 'num_params': 480.0}]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "network = GSCHeb_v0(config)\n",
    "model = DSNNWeightedMag(network, config)\n",
    "model.setup()\n",
    "model.sparse_modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        ...,\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]],\n",
       "\n",
       "\n",
       "        [[[0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.],\n",
       "          [0., 0., 0., 0., 0.]]]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "modules = model.sparse_modules\n",
    "modules[0].m.coactivations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = resnet152()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1           [-1, 16, 32, 32]             448\n",
      "       BatchNorm2d-2           [-1, 16, 32, 32]              32\n",
      "              ReLU-3           [-1, 16, 32, 32]               0\n",
      "            Conv2d-4           [-1, 16, 32, 32]             272\n",
      "       BatchNorm2d-5           [-1, 16, 32, 32]              32\n",
      "              ReLU-6           [-1, 16, 32, 32]               0\n",
      "            Conv2d-7           [-1, 16, 32, 32]           2,320\n",
      "       BatchNorm2d-8           [-1, 16, 32, 32]              32\n",
      "              ReLU-9           [-1, 16, 32, 32]               0\n",
      "           Conv2d-10           [-1, 64, 32, 32]           1,088\n",
      "      BatchNorm2d-11           [-1, 64, 32, 32]             128\n",
      "           Conv2d-12           [-1, 64, 32, 32]           1,088\n",
      "      BatchNorm2d-13           [-1, 64, 32, 32]             128\n",
      "             ReLU-14           [-1, 64, 32, 32]               0\n",
      "       Bottleneck-15           [-1, 64, 32, 32]               0\n",
      "           Conv2d-16           [-1, 16, 32, 32]           1,040\n",
      "      BatchNorm2d-17           [-1, 16, 32, 32]              32\n",
      "             ReLU-18           [-1, 16, 32, 32]               0\n",
      "           Conv2d-19           [-1, 16, 32, 32]           2,320\n",
      "      BatchNorm2d-20           [-1, 16, 32, 32]              32\n",
      "             ReLU-21           [-1, 16, 32, 32]               0\n",
      "           Conv2d-22           [-1, 64, 32, 32]           1,088\n",
      "      BatchNorm2d-23           [-1, 64, 32, 32]             128\n",
      "             ReLU-24           [-1, 64, 32, 32]               0\n",
      "       Bottleneck-25           [-1, 64, 32, 32]               0\n",
      "           Conv2d-26           [-1, 16, 32, 32]           1,040\n",
      "      BatchNorm2d-27           [-1, 16, 32, 32]              32\n",
      "             ReLU-28           [-1, 16, 32, 32]               0\n",
      "           Conv2d-29           [-1, 16, 32, 32]           2,320\n",
      "      BatchNorm2d-30           [-1, 16, 32, 32]              32\n",
      "             ReLU-31           [-1, 16, 32, 32]               0\n",
      "           Conv2d-32           [-1, 64, 32, 32]           1,088\n",
      "      BatchNorm2d-33           [-1, 64, 32, 32]             128\n",
      "             ReLU-34           [-1, 64, 32, 32]               0\n",
      "       Bottleneck-35           [-1, 64, 32, 32]               0\n",
      "           Conv2d-36           [-1, 32, 32, 32]           2,080\n",
      "      BatchNorm2d-37           [-1, 32, 32, 32]              64\n",
      "             ReLU-38           [-1, 32, 32, 32]               0\n",
      "           Conv2d-39           [-1, 32, 16, 16]           9,248\n",
      "      BatchNorm2d-40           [-1, 32, 16, 16]              64\n",
      "             ReLU-41           [-1, 32, 16, 16]               0\n",
      "           Conv2d-42          [-1, 128, 16, 16]           4,224\n",
      "      BatchNorm2d-43          [-1, 128, 16, 16]             256\n",
      "           Conv2d-44          [-1, 128, 16, 16]           8,320\n",
      "      BatchNorm2d-45          [-1, 128, 16, 16]             256\n",
      "             ReLU-46          [-1, 128, 16, 16]               0\n",
      "       Bottleneck-47          [-1, 128, 16, 16]               0\n",
      "           Conv2d-48           [-1, 32, 16, 16]           4,128\n",
      "      BatchNorm2d-49           [-1, 32, 16, 16]              64\n",
      "             ReLU-50           [-1, 32, 16, 16]               0\n",
      "           Conv2d-51           [-1, 32, 16, 16]           9,248\n",
      "      BatchNorm2d-52           [-1, 32, 16, 16]              64\n",
      "             ReLU-53           [-1, 32, 16, 16]               0\n",
      "           Conv2d-54          [-1, 128, 16, 16]           4,224\n",
      "      BatchNorm2d-55          [-1, 128, 16, 16]             256\n",
      "             ReLU-56          [-1, 128, 16, 16]               0\n",
      "       Bottleneck-57          [-1, 128, 16, 16]               0\n",
      "           Conv2d-58           [-1, 32, 16, 16]           4,128\n",
      "      BatchNorm2d-59           [-1, 32, 16, 16]              64\n",
      "             ReLU-60           [-1, 32, 16, 16]               0\n",
      "           Conv2d-61           [-1, 32, 16, 16]           9,248\n",
      "      BatchNorm2d-62           [-1, 32, 16, 16]              64\n",
      "             ReLU-63           [-1, 32, 16, 16]               0\n",
      "           Conv2d-64          [-1, 128, 16, 16]           4,224\n",
      "      BatchNorm2d-65          [-1, 128, 16, 16]             256\n",
      "             ReLU-66          [-1, 128, 16, 16]               0\n",
      "       Bottleneck-67          [-1, 128, 16, 16]               0\n",
      "           Conv2d-68           [-1, 32, 16, 16]           4,128\n",
      "      BatchNorm2d-69           [-1, 32, 16, 16]              64\n",
      "             ReLU-70           [-1, 32, 16, 16]               0\n",
      "           Conv2d-71           [-1, 32, 16, 16]           9,248\n",
      "      BatchNorm2d-72           [-1, 32, 16, 16]              64\n",
      "             ReLU-73           [-1, 32, 16, 16]               0\n",
      "           Conv2d-74          [-1, 128, 16, 16]           4,224\n",
      "      BatchNorm2d-75          [-1, 128, 16, 16]             256\n",
      "             ReLU-76          [-1, 128, 16, 16]               0\n",
      "       Bottleneck-77          [-1, 128, 16, 16]               0\n",
      "           Conv2d-78           [-1, 32, 16, 16]           4,128\n",
      "      BatchNorm2d-79           [-1, 32, 16, 16]              64\n",
      "             ReLU-80           [-1, 32, 16, 16]               0\n",
      "           Conv2d-81           [-1, 32, 16, 16]           9,248\n",
      "      BatchNorm2d-82           [-1, 32, 16, 16]              64\n",
      "             ReLU-83           [-1, 32, 16, 16]               0\n",
      "           Conv2d-84          [-1, 128, 16, 16]           4,224\n",
      "      BatchNorm2d-85          [-1, 128, 16, 16]             256\n",
      "             ReLU-86          [-1, 128, 16, 16]               0\n",
      "       Bottleneck-87          [-1, 128, 16, 16]               0\n",
      "           Conv2d-88           [-1, 32, 16, 16]           4,128\n",
      "      BatchNorm2d-89           [-1, 32, 16, 16]              64\n",
      "             ReLU-90           [-1, 32, 16, 16]               0\n",
      "           Conv2d-91           [-1, 32, 16, 16]           9,248\n",
      "      BatchNorm2d-92           [-1, 32, 16, 16]              64\n",
      "             ReLU-93           [-1, 32, 16, 16]               0\n",
      "           Conv2d-94          [-1, 128, 16, 16]           4,224\n",
      "      BatchNorm2d-95          [-1, 128, 16, 16]             256\n",
      "             ReLU-96          [-1, 128, 16, 16]               0\n",
      "       Bottleneck-97          [-1, 128, 16, 16]               0\n",
      "           Conv2d-98           [-1, 32, 16, 16]           4,128\n",
      "      BatchNorm2d-99           [-1, 32, 16, 16]              64\n",
      "            ReLU-100           [-1, 32, 16, 16]               0\n",
      "          Conv2d-101           [-1, 32, 16, 16]           9,248\n",
      "     BatchNorm2d-102           [-1, 32, 16, 16]              64\n",
      "            ReLU-103           [-1, 32, 16, 16]               0\n",
      "          Conv2d-104          [-1, 128, 16, 16]           4,224\n",
      "     BatchNorm2d-105          [-1, 128, 16, 16]             256\n",
      "            ReLU-106          [-1, 128, 16, 16]               0\n",
      "      Bottleneck-107          [-1, 128, 16, 16]               0\n",
      "          Conv2d-108           [-1, 32, 16, 16]           4,128\n",
      "     BatchNorm2d-109           [-1, 32, 16, 16]              64\n",
      "            ReLU-110           [-1, 32, 16, 16]               0\n",
      "          Conv2d-111           [-1, 32, 16, 16]           9,248\n",
      "     BatchNorm2d-112           [-1, 32, 16, 16]              64\n",
      "            ReLU-113           [-1, 32, 16, 16]               0\n",
      "          Conv2d-114          [-1, 128, 16, 16]           4,224\n",
      "     BatchNorm2d-115          [-1, 128, 16, 16]             256\n",
      "            ReLU-116          [-1, 128, 16, 16]               0\n",
      "      Bottleneck-117          [-1, 128, 16, 16]               0\n",
      "          Conv2d-118           [-1, 64, 16, 16]           8,256\n",
      "     BatchNorm2d-119           [-1, 64, 16, 16]             128\n",
      "            ReLU-120           [-1, 64, 16, 16]               0\n",
      "          Conv2d-121             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-122             [-1, 64, 8, 8]             128\n",
      "            ReLU-123             [-1, 64, 8, 8]               0\n",
      "          Conv2d-124            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-125            [-1, 256, 8, 8]             512\n",
      "          Conv2d-126            [-1, 256, 8, 8]          33,024\n",
      "     BatchNorm2d-127            [-1, 256, 8, 8]             512\n",
      "            ReLU-128            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-129            [-1, 256, 8, 8]               0\n",
      "          Conv2d-130             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-131             [-1, 64, 8, 8]             128\n",
      "            ReLU-132             [-1, 64, 8, 8]               0\n",
      "          Conv2d-133             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-134             [-1, 64, 8, 8]             128\n",
      "            ReLU-135             [-1, 64, 8, 8]               0\n",
      "          Conv2d-136            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-137            [-1, 256, 8, 8]             512\n",
      "            ReLU-138            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-139            [-1, 256, 8, 8]               0\n",
      "          Conv2d-140             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-141             [-1, 64, 8, 8]             128\n",
      "            ReLU-142             [-1, 64, 8, 8]               0\n",
      "          Conv2d-143             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-144             [-1, 64, 8, 8]             128\n",
      "            ReLU-145             [-1, 64, 8, 8]               0\n",
      "          Conv2d-146            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-147            [-1, 256, 8, 8]             512\n",
      "            ReLU-148            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-149            [-1, 256, 8, 8]               0\n",
      "          Conv2d-150             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-151             [-1, 64, 8, 8]             128\n",
      "            ReLU-152             [-1, 64, 8, 8]               0\n",
      "          Conv2d-153             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-154             [-1, 64, 8, 8]             128\n",
      "            ReLU-155             [-1, 64, 8, 8]               0\n",
      "          Conv2d-156            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-157            [-1, 256, 8, 8]             512\n",
      "            ReLU-158            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-159            [-1, 256, 8, 8]               0\n",
      "          Conv2d-160             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-161             [-1, 64, 8, 8]             128\n",
      "            ReLU-162             [-1, 64, 8, 8]               0\n",
      "          Conv2d-163             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-164             [-1, 64, 8, 8]             128\n",
      "            ReLU-165             [-1, 64, 8, 8]               0\n",
      "          Conv2d-166            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-167            [-1, 256, 8, 8]             512\n",
      "            ReLU-168            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-169            [-1, 256, 8, 8]               0\n",
      "          Conv2d-170             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-171             [-1, 64, 8, 8]             128\n",
      "            ReLU-172             [-1, 64, 8, 8]               0\n",
      "          Conv2d-173             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-174             [-1, 64, 8, 8]             128\n",
      "            ReLU-175             [-1, 64, 8, 8]               0\n",
      "          Conv2d-176            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-177            [-1, 256, 8, 8]             512\n",
      "            ReLU-178            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-179            [-1, 256, 8, 8]               0\n",
      "          Conv2d-180             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-181             [-1, 64, 8, 8]             128\n",
      "            ReLU-182             [-1, 64, 8, 8]               0\n",
      "          Conv2d-183             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-184             [-1, 64, 8, 8]             128\n",
      "            ReLU-185             [-1, 64, 8, 8]               0\n",
      "          Conv2d-186            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-187            [-1, 256, 8, 8]             512\n",
      "            ReLU-188            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-189            [-1, 256, 8, 8]               0\n",
      "          Conv2d-190             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-191             [-1, 64, 8, 8]             128\n",
      "            ReLU-192             [-1, 64, 8, 8]               0\n",
      "          Conv2d-193             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-194             [-1, 64, 8, 8]             128\n",
      "            ReLU-195             [-1, 64, 8, 8]               0\n",
      "          Conv2d-196            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-197            [-1, 256, 8, 8]             512\n",
      "            ReLU-198            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-199            [-1, 256, 8, 8]               0\n",
      "          Conv2d-200             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-201             [-1, 64, 8, 8]             128\n",
      "            ReLU-202             [-1, 64, 8, 8]               0\n",
      "          Conv2d-203             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-204             [-1, 64, 8, 8]             128\n",
      "            ReLU-205             [-1, 64, 8, 8]               0\n",
      "          Conv2d-206            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-207            [-1, 256, 8, 8]             512\n",
      "            ReLU-208            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-209            [-1, 256, 8, 8]               0\n",
      "          Conv2d-210             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-211             [-1, 64, 8, 8]             128\n",
      "            ReLU-212             [-1, 64, 8, 8]               0\n",
      "          Conv2d-213             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-214             [-1, 64, 8, 8]             128\n",
      "            ReLU-215             [-1, 64, 8, 8]               0\n",
      "          Conv2d-216            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-217            [-1, 256, 8, 8]             512\n",
      "            ReLU-218            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-219            [-1, 256, 8, 8]               0\n",
      "          Conv2d-220             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-221             [-1, 64, 8, 8]             128\n",
      "            ReLU-222             [-1, 64, 8, 8]               0\n",
      "          Conv2d-223             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-224             [-1, 64, 8, 8]             128\n",
      "            ReLU-225             [-1, 64, 8, 8]               0\n",
      "          Conv2d-226            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-227            [-1, 256, 8, 8]             512\n",
      "            ReLU-228            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-229            [-1, 256, 8, 8]               0\n",
      "          Conv2d-230             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-231             [-1, 64, 8, 8]             128\n",
      "            ReLU-232             [-1, 64, 8, 8]               0\n",
      "          Conv2d-233             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-234             [-1, 64, 8, 8]             128\n",
      "            ReLU-235             [-1, 64, 8, 8]               0\n",
      "          Conv2d-236            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-237            [-1, 256, 8, 8]             512\n",
      "            ReLU-238            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-239            [-1, 256, 8, 8]               0\n",
      "          Conv2d-240             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-241             [-1, 64, 8, 8]             128\n",
      "            ReLU-242             [-1, 64, 8, 8]               0\n",
      "          Conv2d-243             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-244             [-1, 64, 8, 8]             128\n",
      "            ReLU-245             [-1, 64, 8, 8]               0\n",
      "          Conv2d-246            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-247            [-1, 256, 8, 8]             512\n",
      "            ReLU-248            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-249            [-1, 256, 8, 8]               0\n",
      "          Conv2d-250             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-251             [-1, 64, 8, 8]             128\n",
      "            ReLU-252             [-1, 64, 8, 8]               0\n",
      "          Conv2d-253             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-254             [-1, 64, 8, 8]             128\n",
      "            ReLU-255             [-1, 64, 8, 8]               0\n",
      "          Conv2d-256            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-257            [-1, 256, 8, 8]             512\n",
      "            ReLU-258            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-259            [-1, 256, 8, 8]               0\n",
      "          Conv2d-260             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-261             [-1, 64, 8, 8]             128\n",
      "            ReLU-262             [-1, 64, 8, 8]               0\n",
      "          Conv2d-263             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-264             [-1, 64, 8, 8]             128\n",
      "            ReLU-265             [-1, 64, 8, 8]               0\n",
      "          Conv2d-266            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-267            [-1, 256, 8, 8]             512\n",
      "            ReLU-268            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-269            [-1, 256, 8, 8]               0\n",
      "          Conv2d-270             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-271             [-1, 64, 8, 8]             128\n",
      "            ReLU-272             [-1, 64, 8, 8]               0\n",
      "          Conv2d-273             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-274             [-1, 64, 8, 8]             128\n",
      "            ReLU-275             [-1, 64, 8, 8]               0\n",
      "          Conv2d-276            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-277            [-1, 256, 8, 8]             512\n",
      "            ReLU-278            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-279            [-1, 256, 8, 8]               0\n",
      "          Conv2d-280             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-281             [-1, 64, 8, 8]             128\n",
      "            ReLU-282             [-1, 64, 8, 8]               0\n",
      "          Conv2d-283             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-284             [-1, 64, 8, 8]             128\n",
      "            ReLU-285             [-1, 64, 8, 8]               0\n",
      "          Conv2d-286            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-287            [-1, 256, 8, 8]             512\n",
      "            ReLU-288            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-289            [-1, 256, 8, 8]               0\n",
      "          Conv2d-290             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-291             [-1, 64, 8, 8]             128\n",
      "            ReLU-292             [-1, 64, 8, 8]               0\n",
      "          Conv2d-293             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-294             [-1, 64, 8, 8]             128\n",
      "            ReLU-295             [-1, 64, 8, 8]               0\n",
      "          Conv2d-296            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-297            [-1, 256, 8, 8]             512\n",
      "            ReLU-298            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-299            [-1, 256, 8, 8]               0\n",
      "          Conv2d-300             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-301             [-1, 64, 8, 8]             128\n",
      "            ReLU-302             [-1, 64, 8, 8]               0\n",
      "          Conv2d-303             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-304             [-1, 64, 8, 8]             128\n",
      "            ReLU-305             [-1, 64, 8, 8]               0\n",
      "          Conv2d-306            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-307            [-1, 256, 8, 8]             512\n",
      "            ReLU-308            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-309            [-1, 256, 8, 8]               0\n",
      "          Conv2d-310             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-311             [-1, 64, 8, 8]             128\n",
      "            ReLU-312             [-1, 64, 8, 8]               0\n",
      "          Conv2d-313             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-314             [-1, 64, 8, 8]             128\n",
      "            ReLU-315             [-1, 64, 8, 8]               0\n",
      "          Conv2d-316            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-317            [-1, 256, 8, 8]             512\n",
      "            ReLU-318            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-319            [-1, 256, 8, 8]               0\n",
      "          Conv2d-320             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-321             [-1, 64, 8, 8]             128\n",
      "            ReLU-322             [-1, 64, 8, 8]               0\n",
      "          Conv2d-323             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-324             [-1, 64, 8, 8]             128\n",
      "            ReLU-325             [-1, 64, 8, 8]               0\n",
      "          Conv2d-326            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-327            [-1, 256, 8, 8]             512\n",
      "            ReLU-328            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-329            [-1, 256, 8, 8]               0\n",
      "          Conv2d-330             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-331             [-1, 64, 8, 8]             128\n",
      "            ReLU-332             [-1, 64, 8, 8]               0\n",
      "          Conv2d-333             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-334             [-1, 64, 8, 8]             128\n",
      "            ReLU-335             [-1, 64, 8, 8]               0\n",
      "          Conv2d-336            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-337            [-1, 256, 8, 8]             512\n",
      "            ReLU-338            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-339            [-1, 256, 8, 8]               0\n",
      "          Conv2d-340             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-341             [-1, 64, 8, 8]             128\n",
      "            ReLU-342             [-1, 64, 8, 8]               0\n",
      "          Conv2d-343             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-344             [-1, 64, 8, 8]             128\n",
      "            ReLU-345             [-1, 64, 8, 8]               0\n",
      "          Conv2d-346            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-347            [-1, 256, 8, 8]             512\n",
      "            ReLU-348            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-349            [-1, 256, 8, 8]               0\n",
      "          Conv2d-350             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-351             [-1, 64, 8, 8]             128\n",
      "            ReLU-352             [-1, 64, 8, 8]               0\n",
      "          Conv2d-353             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-354             [-1, 64, 8, 8]             128\n",
      "            ReLU-355             [-1, 64, 8, 8]               0\n",
      "          Conv2d-356            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-357            [-1, 256, 8, 8]             512\n",
      "            ReLU-358            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-359            [-1, 256, 8, 8]               0\n",
      "          Conv2d-360             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-361             [-1, 64, 8, 8]             128\n",
      "            ReLU-362             [-1, 64, 8, 8]               0\n",
      "          Conv2d-363             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-364             [-1, 64, 8, 8]             128\n",
      "            ReLU-365             [-1, 64, 8, 8]               0\n",
      "          Conv2d-366            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-367            [-1, 256, 8, 8]             512\n",
      "            ReLU-368            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-369            [-1, 256, 8, 8]               0\n",
      "          Conv2d-370             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-371             [-1, 64, 8, 8]             128\n",
      "            ReLU-372             [-1, 64, 8, 8]               0\n",
      "          Conv2d-373             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-374             [-1, 64, 8, 8]             128\n",
      "            ReLU-375             [-1, 64, 8, 8]               0\n",
      "          Conv2d-376            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-377            [-1, 256, 8, 8]             512\n",
      "            ReLU-378            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-379            [-1, 256, 8, 8]               0\n",
      "          Conv2d-380             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-381             [-1, 64, 8, 8]             128\n",
      "            ReLU-382             [-1, 64, 8, 8]               0\n",
      "          Conv2d-383             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-384             [-1, 64, 8, 8]             128\n",
      "            ReLU-385             [-1, 64, 8, 8]               0\n",
      "          Conv2d-386            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-387            [-1, 256, 8, 8]             512\n",
      "            ReLU-388            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-389            [-1, 256, 8, 8]               0\n",
      "          Conv2d-390             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-391             [-1, 64, 8, 8]             128\n",
      "            ReLU-392             [-1, 64, 8, 8]               0\n",
      "          Conv2d-393             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-394             [-1, 64, 8, 8]             128\n",
      "            ReLU-395             [-1, 64, 8, 8]               0\n",
      "          Conv2d-396            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-397            [-1, 256, 8, 8]             512\n",
      "            ReLU-398            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-399            [-1, 256, 8, 8]               0\n",
      "          Conv2d-400             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-401             [-1, 64, 8, 8]             128\n",
      "            ReLU-402             [-1, 64, 8, 8]               0\n",
      "          Conv2d-403             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-404             [-1, 64, 8, 8]             128\n",
      "            ReLU-405             [-1, 64, 8, 8]               0\n",
      "          Conv2d-406            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-407            [-1, 256, 8, 8]             512\n",
      "            ReLU-408            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-409            [-1, 256, 8, 8]               0\n",
      "          Conv2d-410             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-411             [-1, 64, 8, 8]             128\n",
      "            ReLU-412             [-1, 64, 8, 8]               0\n",
      "          Conv2d-413             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-414             [-1, 64, 8, 8]             128\n",
      "            ReLU-415             [-1, 64, 8, 8]               0\n",
      "          Conv2d-416            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-417            [-1, 256, 8, 8]             512\n",
      "            ReLU-418            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-419            [-1, 256, 8, 8]               0\n",
      "          Conv2d-420             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-421             [-1, 64, 8, 8]             128\n",
      "            ReLU-422             [-1, 64, 8, 8]               0\n",
      "          Conv2d-423             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-424             [-1, 64, 8, 8]             128\n",
      "            ReLU-425             [-1, 64, 8, 8]               0\n",
      "          Conv2d-426            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-427            [-1, 256, 8, 8]             512\n",
      "            ReLU-428            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-429            [-1, 256, 8, 8]               0\n",
      "          Conv2d-430             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-431             [-1, 64, 8, 8]             128\n",
      "            ReLU-432             [-1, 64, 8, 8]               0\n",
      "          Conv2d-433             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-434             [-1, 64, 8, 8]             128\n",
      "            ReLU-435             [-1, 64, 8, 8]               0\n",
      "          Conv2d-436            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-437            [-1, 256, 8, 8]             512\n",
      "            ReLU-438            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-439            [-1, 256, 8, 8]               0\n",
      "          Conv2d-440             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-441             [-1, 64, 8, 8]             128\n",
      "            ReLU-442             [-1, 64, 8, 8]               0\n",
      "          Conv2d-443             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-444             [-1, 64, 8, 8]             128\n",
      "            ReLU-445             [-1, 64, 8, 8]               0\n",
      "          Conv2d-446            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-447            [-1, 256, 8, 8]             512\n",
      "            ReLU-448            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-449            [-1, 256, 8, 8]               0\n",
      "          Conv2d-450             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-451             [-1, 64, 8, 8]             128\n",
      "            ReLU-452             [-1, 64, 8, 8]               0\n",
      "          Conv2d-453             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-454             [-1, 64, 8, 8]             128\n",
      "            ReLU-455             [-1, 64, 8, 8]               0\n",
      "          Conv2d-456            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-457            [-1, 256, 8, 8]             512\n",
      "            ReLU-458            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-459            [-1, 256, 8, 8]               0\n",
      "          Conv2d-460             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-461             [-1, 64, 8, 8]             128\n",
      "            ReLU-462             [-1, 64, 8, 8]               0\n",
      "          Conv2d-463             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-464             [-1, 64, 8, 8]             128\n",
      "            ReLU-465             [-1, 64, 8, 8]               0\n",
      "          Conv2d-466            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-467            [-1, 256, 8, 8]             512\n",
      "            ReLU-468            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-469            [-1, 256, 8, 8]               0\n",
      "          Conv2d-470             [-1, 64, 8, 8]          16,448\n",
      "     BatchNorm2d-471             [-1, 64, 8, 8]             128\n",
      "            ReLU-472             [-1, 64, 8, 8]               0\n",
      "          Conv2d-473             [-1, 64, 8, 8]          36,928\n",
      "     BatchNorm2d-474             [-1, 64, 8, 8]             128\n",
      "            ReLU-475             [-1, 64, 8, 8]               0\n",
      "          Conv2d-476            [-1, 256, 8, 8]          16,640\n",
      "     BatchNorm2d-477            [-1, 256, 8, 8]             512\n",
      "            ReLU-478            [-1, 256, 8, 8]               0\n",
      "      Bottleneck-479            [-1, 256, 8, 8]               0\n",
      "AdaptiveAvgPool2d-480            [-1, 256, 1, 1]               0\n",
      "         Flatten-481                  [-1, 256]               0\n",
      "          Linear-482                   [-1, 10]           2,570\n",
      "================================================================\n",
      "Total params: 2,741,386\n",
      "Trainable params: 2,741,386\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.01\n",
      "Forward/backward pass size (MB): 46.97\n",
      "Params size (MB): 10.46\n",
      "Estimated Total Size (MB): 57.44\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "summary(resnet152(), input_size=(3,32,32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Wide-Resnet 28x2\n",
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1           [-1, 16, 32, 32]             448\n",
      "       BatchNorm2d-2           [-1, 16, 32, 32]              32\n",
      "              ReLU-3           [-1, 16, 32, 32]               0\n",
      "            Conv2d-4           [-1, 32, 32, 32]           4,640\n",
      "           Dropout-5           [-1, 32, 32, 32]               0\n",
      "       BatchNorm2d-6           [-1, 32, 32, 32]              64\n",
      "              ReLU-7           [-1, 32, 32, 32]               0\n",
      "            Conv2d-8           [-1, 32, 32, 32]           9,248\n",
      "            Conv2d-9           [-1, 32, 32, 32]             544\n",
      "        WideBasic-10           [-1, 32, 32, 32]               0\n",
      "      BatchNorm2d-11           [-1, 32, 32, 32]              64\n",
      "             ReLU-12           [-1, 32, 32, 32]               0\n",
      "           Conv2d-13           [-1, 32, 32, 32]           9,248\n",
      "          Dropout-14           [-1, 32, 32, 32]               0\n",
      "      BatchNorm2d-15           [-1, 32, 32, 32]              64\n",
      "             ReLU-16           [-1, 32, 32, 32]               0\n",
      "           Conv2d-17           [-1, 32, 32, 32]           9,248\n",
      "        WideBasic-18           [-1, 32, 32, 32]               0\n",
      "      BatchNorm2d-19           [-1, 32, 32, 32]              64\n",
      "             ReLU-20           [-1, 32, 32, 32]               0\n",
      "           Conv2d-21           [-1, 32, 32, 32]           9,248\n",
      "          Dropout-22           [-1, 32, 32, 32]               0\n",
      "      BatchNorm2d-23           [-1, 32, 32, 32]              64\n",
      "             ReLU-24           [-1, 32, 32, 32]               0\n",
      "           Conv2d-25           [-1, 32, 32, 32]           9,248\n",
      "        WideBasic-26           [-1, 32, 32, 32]               0\n",
      "      BatchNorm2d-27           [-1, 32, 32, 32]              64\n",
      "             ReLU-28           [-1, 32, 32, 32]               0\n",
      "           Conv2d-29           [-1, 32, 32, 32]           9,248\n",
      "          Dropout-30           [-1, 32, 32, 32]               0\n",
      "      BatchNorm2d-31           [-1, 32, 32, 32]              64\n",
      "             ReLU-32           [-1, 32, 32, 32]               0\n",
      "           Conv2d-33           [-1, 32, 32, 32]           9,248\n",
      "        WideBasic-34           [-1, 32, 32, 32]               0\n",
      "      BatchNorm2d-35           [-1, 32, 32, 32]              64\n",
      "             ReLU-36           [-1, 32, 32, 32]               0\n",
      "           Conv2d-37           [-1, 64, 32, 32]          18,496\n",
      "          Dropout-38           [-1, 64, 32, 32]               0\n",
      "      BatchNorm2d-39           [-1, 64, 32, 32]             128\n",
      "             ReLU-40           [-1, 64, 32, 32]               0\n",
      "           Conv2d-41           [-1, 64, 16, 16]          36,928\n",
      "           Conv2d-42           [-1, 64, 16, 16]           2,112\n",
      "        WideBasic-43           [-1, 64, 16, 16]               0\n",
      "      BatchNorm2d-44           [-1, 64, 16, 16]             128\n",
      "             ReLU-45           [-1, 64, 16, 16]               0\n",
      "           Conv2d-46           [-1, 64, 16, 16]          36,928\n",
      "          Dropout-47           [-1, 64, 16, 16]               0\n",
      "      BatchNorm2d-48           [-1, 64, 16, 16]             128\n",
      "             ReLU-49           [-1, 64, 16, 16]               0\n",
      "           Conv2d-50           [-1, 64, 16, 16]          36,928\n",
      "        WideBasic-51           [-1, 64, 16, 16]               0\n",
      "      BatchNorm2d-52           [-1, 64, 16, 16]             128\n",
      "             ReLU-53           [-1, 64, 16, 16]               0\n",
      "           Conv2d-54           [-1, 64, 16, 16]          36,928\n",
      "          Dropout-55           [-1, 64, 16, 16]               0\n",
      "      BatchNorm2d-56           [-1, 64, 16, 16]             128\n",
      "             ReLU-57           [-1, 64, 16, 16]               0\n",
      "           Conv2d-58           [-1, 64, 16, 16]          36,928\n",
      "        WideBasic-59           [-1, 64, 16, 16]               0\n",
      "      BatchNorm2d-60           [-1, 64, 16, 16]             128\n",
      "             ReLU-61           [-1, 64, 16, 16]               0\n",
      "           Conv2d-62           [-1, 64, 16, 16]          36,928\n",
      "          Dropout-63           [-1, 64, 16, 16]               0\n",
      "      BatchNorm2d-64           [-1, 64, 16, 16]             128\n",
      "             ReLU-65           [-1, 64, 16, 16]               0\n",
      "           Conv2d-66           [-1, 64, 16, 16]          36,928\n",
      "        WideBasic-67           [-1, 64, 16, 16]               0\n",
      "      BatchNorm2d-68           [-1, 64, 16, 16]             128\n",
      "             ReLU-69           [-1, 64, 16, 16]               0\n",
      "           Conv2d-70          [-1, 128, 16, 16]          73,856\n",
      "          Dropout-71          [-1, 128, 16, 16]               0\n",
      "      BatchNorm2d-72          [-1, 128, 16, 16]             256\n",
      "             ReLU-73          [-1, 128, 16, 16]               0\n",
      "           Conv2d-74            [-1, 128, 8, 8]         147,584\n",
      "           Conv2d-75            [-1, 128, 8, 8]           8,320\n",
      "        WideBasic-76            [-1, 128, 8, 8]               0\n",
      "      BatchNorm2d-77            [-1, 128, 8, 8]             256\n",
      "             ReLU-78            [-1, 128, 8, 8]               0\n",
      "           Conv2d-79            [-1, 128, 8, 8]         147,584\n",
      "          Dropout-80            [-1, 128, 8, 8]               0\n",
      "      BatchNorm2d-81            [-1, 128, 8, 8]             256\n",
      "             ReLU-82            [-1, 128, 8, 8]               0\n",
      "           Conv2d-83            [-1, 128, 8, 8]         147,584\n",
      "        WideBasic-84            [-1, 128, 8, 8]               0\n",
      "      BatchNorm2d-85            [-1, 128, 8, 8]             256\n",
      "             ReLU-86            [-1, 128, 8, 8]               0\n",
      "           Conv2d-87            [-1, 128, 8, 8]         147,584\n",
      "          Dropout-88            [-1, 128, 8, 8]               0\n",
      "      BatchNorm2d-89            [-1, 128, 8, 8]             256\n",
      "             ReLU-90            [-1, 128, 8, 8]               0\n",
      "           Conv2d-91            [-1, 128, 8, 8]         147,584\n",
      "        WideBasic-92            [-1, 128, 8, 8]               0\n",
      "      BatchNorm2d-93            [-1, 128, 8, 8]             256\n",
      "             ReLU-94            [-1, 128, 8, 8]               0\n",
      "           Conv2d-95            [-1, 128, 8, 8]         147,584\n",
      "          Dropout-96            [-1, 128, 8, 8]               0\n",
      "      BatchNorm2d-97            [-1, 128, 8, 8]             256\n",
      "             ReLU-98            [-1, 128, 8, 8]               0\n",
      "           Conv2d-99            [-1, 128, 8, 8]         147,584\n",
      "       WideBasic-100            [-1, 128, 8, 8]               0\n",
      "     BatchNorm2d-101            [-1, 128, 8, 8]             256\n",
      "            ReLU-102            [-1, 128, 8, 8]               0\n",
      "AdaptiveAvgPool2d-103            [-1, 128, 1, 1]               0\n",
      "         Flatten-104                  [-1, 128]               0\n",
      "          Linear-105                   [-1, 10]           1,290\n",
      "================================================================\n",
      "Total params: 1,469,642\n",
      "Trainable params: 1,469,642\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.01\n",
      "Forward/backward pass size (MB): 17.06\n",
      "Params size (MB): 5.61\n",
      "Estimated Total Size (MB): 22.68\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "summary(WideResNet(), input_size=(3,32,32))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
